{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [],
   "source": [
    "## import packages\n",
    "import torch\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "from model import ORCADLConfig, ORCADLModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 66, 128, 360]) torch.Size([1, 2, 128, 360])\n"
     ]
    }
   ],
   "source": [
    "## Download the contents in stat and ckpt folder\n",
    "\n",
    "## Prepare input data\n",
    "\n",
    "# salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height, zonal wind stress, meridional wind stress\n",
    "variables = ['salt', 'pottmp', 'sst', 'ucur', 'vcur', 'sshg', 'uflx', 'vflx']\n",
    "# the first six entries become ocean_vars, the last two (wind stress) become atmo_vars\n",
    "ocean_names, atmo_names = variables[:-2], variables[-2:]\n",
    "\n",
    "# load mean and std\n",
    "stat = {\n",
    "    'mean': {v: np.load(f\"./stat/mean/{v}.npy\") for v in variables},\n",
    "    'std': {v: np.load(f\"./stat/std/{v}.npy\") for v in variables}\n",
    "}\n",
    "## Note: the unit for temperature is °C, salinity is g/kg, current is m/s, ssh is m, wind stress is N/m²\n",
    "\n",
    "month = 0  # index into the statistics; the statistical values differ from month to month\n",
    "\n",
    "\n",
    "def load_normalized(name, month_index):\n",
    "    \"\"\"Read ./example_data/<name>.nc and standardize it with the statistics stored at `month_index`.\n",
    "\n",
    "    Important: the units of the input data must be consistent with the units of the statistical values.\n",
    "    The dataset is opened in a context manager so the file is closed as soon as the values are read.\n",
    "\n",
    "    Returns the array (values - mean) / std for variable `name`.\n",
    "    \"\"\"\n",
    "    with xr.open_dataset(f\"./example_data/{name}.nc\") as ds:\n",
    "        values = ds[name].values\n",
    "    return (values - stat['mean'][name][month_index]) / stat['std'][name][month_index]\n",
    "\n",
    "\n",
    "# load data\n",
    "ocean_fields = []\n",
    "for v in ocean_names:\n",
    "    normed_data = load_normalized(v, month)\n",
    "    # 3-D fields are used as they are; any other field gets a leading axis so everything concatenates on axis 0\n",
    "    ocean_fields.append(normed_data if normed_data.ndim == 3 else normed_data[None])\n",
    "# atmospheric fields always get the leading axis\n",
    "atmo_fields = [load_normalized(v, month)[None] for v in atmo_names]\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# stack along the channel axis, replace NaN with 0 (np.nan_to_num default), add the batch axis and cast to float32\n",
    "ocean_vars = torch.from_numpy(np.nan_to_num(np.concatenate(ocean_fields, axis=0)))[None].float().to(device) # (1, 66, 128, 360)\n",
    "atmo_vars = torch.from_numpy(np.nan_to_num(np.concatenate(atmo_fields, axis=0)))[None].float().to(device) # (1, 2, 128, 360)\n",
    "\n",
    "print(ocean_vars.shape, atmo_vars.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 168,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Setup ORCA-DL\n",
    "# build the network from its JSON configuration\n",
    "orcadl_config = ORCADLConfig.from_json_file('./model_config.json')\n",
    "model = ORCADLModel(orcadl_config)\n",
    "\n",
    "# restore the weights on CPU first, then move the model to the selected device\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source\n",
    "checkpoint_path = './ckpt/seed_1.bin'\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n",
    "model.to(device)\n",
    "\n",
    "# put the model in eval mode (kept as the last expression, as in the original cell)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 66, 128, 360])\n",
      "(1, 1, 128, 360)\n",
      "torch.Size([1, 6, 66, 128, 360])\n",
      "torch.Size([4, 6, 66, 128, 360])\n"
     ]
    }
   ],
   "source": [
    "## Run the model\n",
    "\n",
    "with torch.no_grad():\n",
    "    # single step\n",
    "    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=1)\n",
    "    print(output.preds.shape) # (1, 66, 128, 360)\n",
    "\n",
    "    # Post-process the output\n",
    "    preds = output.preds.detach().cpu().numpy()\n",
    "    # salinity, potential temp, sea surface temp, zonal current, meridional current, sea surface height\n",
    "    pred_all_variables = np.split(preds, model.split_chans, axis=1) # split by channels\n",
    "    # pred_all_variables contains the prediction of all ocean variables, in the same order as the input ocean variables.\n",
    "\n",
    "    # inverse the normalization; the statistics must be those of the predicted month\n",
    "    # wrap around so the last month index rolls over to index 0 instead of raising IndexError\n",
    "    # (assumes the first axis of the statistics enumerates the months cyclically)\n",
    "    pred_month = (month + 1) % len(stat['mean']['sst'])\n",
    "    pred_sst = pred_all_variables[2] * stat['std']['sst'][pred_month] + stat['mean']['sst'][pred_month]\n",
    "    print(pred_sst.shape) # (1, 1, 128, 360)\n",
    "\n",
    "    # multi steps\n",
    "    steps = 6\n",
    "    output = model(ocean_vars=ocean_vars, atmo_vars=atmo_vars, predict_time_steps=steps)\n",
    "    print(output.preds.shape) # (1, steps, 66, 128, 360)\n",
    "\n",
    "    # batch input\n",
    "    # the repeated tensors get their own names so the original inputs are not overwritten\n",
    "    # and re-running this cell always gives the same shapes\n",
    "    batch_size = 4\n",
    "    ocean_vars_batch = ocean_vars.repeat(batch_size, 1, 1, 1)\n",
    "    atmo_vars_batch = atmo_vars.repeat(batch_size, 1, 1, 1)\n",
    "    output = model(ocean_vars=ocean_vars_batch, atmo_vars=atmo_vars_batch, predict_time_steps=steps)\n",
    "    print(output.preds.shape) # (batch_size, steps, 66, 128, 360)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
